import numpy as np
import random
import torch

# Vocabulary: three special tokens, the ten digits, then the lowercase
# letters in QWERTY keyboard order.
_tokens = '<SOS>,<EOS>,<PAD>,0,1,2,3,4,5,6,7,8,9,q,w,e,r,t,y,u,i,o,p,a,s,d,f,g,h,j,k,l,z,x,c,v,b,n,m'.split(',')
# Source-side vocabulary: token -> index.
dic_x = dict(zip(_tokens, range(len(_tokens))))
# Reverse lookup: index -> token (insertion order preserves the indices).
dic_xr = list(dic_x)
# Target-side vocabulary: same indices, but tokens are uppercased.
dic_y = {token.upper(): index for token, index in dic_x.items()}
dic_yr = list(dic_y)


def get_data():
    """Sample one synthetic (source, target) pair for a seq2seq toy task.

    The source is a random-length sequence of digits and lowercase
    letters.  The target is derived from it: letters are uppercased,
    each digit d is replaced by 9 - d, the last token is duplicated,
    and the whole sequence is reversed.  Both sides are wrapped in
    <SOS>/<EOS>, padded with <PAD>, encoded through the module
    vocabularies, and returned as LongTensors of length 50 (source)
    and 51 (target).
    """
    # Candidate tokens: the ten digits followed by the lowercase
    # letters in QWERTY keyboard order.
    words = list('0123456789qwertyuiopasdfghjklzxcvbnm')
    # Per-token sampling weights: 1..10 for the digits, 1..26 for the
    # letters, normalized into a probability distribution.
    weights = np.array([*range(1, 11), *range(1, 27)])
    probs = weights / weights.sum()

    # Draw a random-length sample, with replacement.
    n = random.randint(30, 48)
    sample = np.random.choice(words, size=n, replace=True, p=probs).tolist()

    def transform(token):
        # Uppercase letters pass through; digit d maps to 9 - d.
        token = token.upper()
        if token.isdigit():
            return str(9 - int(token))
        return token

    target = [transform(token) for token in sample]
    # Duplicate the final token, then reverse the whole sequence.
    target.append(target[-1])
    target.reverse()

    # Wrap both sides with start/end markers.
    source = ['<SOS>'] + sample + ['<EOS>']
    target = ['<SOS>'] + target + ['<EOS>']
    # Pad with <PAD> and clip to the fixed lengths (50 source, 51 target).
    source = (source + ['<PAD>'] * 50)[:50]
    target = (target + ['<PAD>'] * 50)[:51]
    # Encode tokens to vocabulary indices and convert to tensors.
    source = torch.LongTensor([dic_x[token] for token in source])
    target = torch.LongTensor([dic_y[token] for token in target])
    return source, target


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that synthesizes a fresh sample on every access."""

    def __init__(self):
        super().__init__()

    def __len__(self):
        # Nominal epoch length; items are generated on demand, so this
        # only controls how many samples one full pass produces.
        return 100000

    def __getitem__(self, i):
        # The index is ignored: every item is an independent random draw.
        return get_data()


# Iterate the dataset in shuffled batches of 8; a trailing incomplete
# batch is dropped, and the default collation stacks the tensor pairs.
loader = torch.utils.data.DataLoader(
    dataset=Dataset(),
    batch_size=8,
    shuffle=True,
    drop_last=True,
    collate_fn=None,
)
